{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 学习率调度器\n",
    ":label:`sec_scheduler`\n",
    "\n",
    "到目前为止，我们主要关注如何更新权重向量的优化算法，而不是它们的更新速率。\n",
    "然而，调整学习率通常与实际算法同样重要，有如下几方面需要考虑：\n",
    "\n",
    "* 首先，学习率的大小很重要。如果它太大，优化就会发散；如果它太小，训练就会需要过长时间，或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要（有关详细信息，请参见 :numref:`sec_momentum`）。直观地说，这是最不敏感与最敏感方向的变化量的比率。\n",
    "* 其次，衰减速率同样很重要。如果学习率持续过高，我们可能最终会在最小值附近弹跳，从而无法达到最优解。 :numref:`sec_minibatch_sgd`比较详细地讨论了这一点，在 :numref:`sec_sgd`中我们则分析了性能保证。简而言之，我们希望速率衰减，但要比$\\mathcal{O}(t^{-\\frac{1}{2}})$慢，这样能成为解决凸问题的不错选择。\n",
    "* 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式（详情请参阅 :numref:`sec_numerical_stability`），又关系到它们最初的演变方式。这被戏称为*预热*（warmup），即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处，特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。\n",
    "* 最后，还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围，我们建议读者阅读 :cite:`Izmailov.Podoprikhin.Garipov.ea.2018`来了解个中细节。例如，如何通过对整个路径参数求平均值来获得更好的解。\n",
    "\n",
    "鉴于管理学习率需要很多细节，因此大多数深度学习框架都有自动应对这个问题的工具。\n",
    "在本章中，我们将梳理不同的调度策略对准确性的影响，并展示如何通过*学习率调度器*（learning rate scheduler）来有效管理。\n",
    "\n",
    "## 一个简单的问题\n",
    "\n",
    "我们从一个简单的问题开始，这个问题可以轻松计算，但足以说明要义。\n",
    "为此，我们选择了一个稍微现代化的LeNet版本（激活函数使用`relu`而不是`sigmoid`，汇聚层使用最大汇聚层而不是平均汇聚层），并应用于Fashion-MNIST数据集。\n",
    "此外，我们混合网络以提高性能。\n",
    "由于大多数代码都是标准的，我们只介绍基础知识，而不做进一步的详细讨论。如果需要，请参阅 :numref:`chap_cnn`进行复习。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/GradDescUtils.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/StopWatch.java\n",
    "\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/TrainingChapter11.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import org.apache.commons.lang3.ArrayUtils;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SequentialBlock net = new SequentialBlock();\n",
    "\n",
    "net.add(Conv2d.builder()\n",
    "        .setKernelShape(new Shape(5, 5))\n",
    "        .optPadding(new Shape(2, 2))\n",
    "        .setFilters(1)\n",
    "        .build());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));\n",
    "net.add(Conv2d.builder()\n",
    "        .setKernelShape(new Shape(5, 5))\n",
    "        .setFilters(1)\n",
    "        .build());\n",
    "net.add(Blocks.batchFlattenBlock());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(Linear.builder().setUnits(120).build());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(Linear.builder().setUnits(84).build());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(Linear.builder().setUnits(10).build());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "RandomAccessDataset trainDataset = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, false)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "RandomAccessDataset testDataset = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, false)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "double[] trainLoss;\n",
    "double[] testAccuracy;\n",
    "double[] epochCount;\n",
    "double[] trainAccuracy;\n",
    "\n",
    "public static void train(RandomAccessDataset trainIter, RandomAccessDataset testIter,\n",
    "                             int numEpochs, Trainer trainer) throws IOException, TranslateException {\n",
    "    epochCount = new double[numEpochs];\n",
    "\n",
    "    for (int i = 0; i < epochCount.length; i++) {\n",
    "        epochCount[i] = (i + 1);\n",
    "    }\n",
    "\n",
    "    double avgTrainTimePerEpoch = 0;\n",
    "    Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "\n",
    "    trainer.setMetrics(new Metrics());\n",
    "\n",
    "    EasyTrain.fit(trainer, numEpochs, trainIter, testIter);\n",
    "\n",
    "    Metrics metrics = trainer.getMetrics();\n",
    "\n",
    "    trainer.getEvaluators().stream()\n",
    "            .forEach(evaluator -> {\n",
    "                evaluatorMetrics.put(\"train_epoch_\" + evaluator.getName(), metrics.getMetric(\"train_epoch_\" + evaluator.getName()).stream()\n",
    "                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "                evaluatorMetrics.put(\"validate_epoch_\" + evaluator.getName(), metrics.getMetric(\"validate_epoch_\" + evaluator.getName()).stream()\n",
    "                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "            });\n",
    "\n",
    "    avgTrainTimePerEpoch = metrics.mean(\"epoch\");\n",
    "\n",
    "    trainLoss = evaluatorMetrics.get(\"train_epoch_SoftmaxCrossEntropyLoss\");\n",
    "    trainAccuracy = evaluatorMetrics.get(\"train_epoch_Accuracy\");\n",
    "    testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n",
    "\n",
    "    System.out.printf(\"loss %.3f,\" , trainLoss[numEpochs-1]);\n",
    "    System.out.printf(\" train acc %.3f,\" , trainAccuracy[numEpochs-1]);\n",
    "    System.out.printf(\" test acc %.3f\\n\" , testAccuracy[numEpochs-1]);\n",
    "    System.out.printf(\"%.1f examples/sec \\n\", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们来看看如果使用默认设置，调用此算法会发生什么。\n",
    "例如设学习率为$0.3$并训练$30$次迭代。\n",
    "留意在超过了某点、测试准确度方面的进展停滞时，训练准确度将如何继续提高。\n",
    "两条曲线之间的间隙表示过拟合。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "float lr = 0.3f;\n",
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "\n",
    "Model model = Model.newInstance(\"Modern LeNet\");\n",
    "model.setBlock(net);\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "Tracker lrt = Tracker.fixed(lr);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "train(trainDataset, testDataset, numEpochs, trainer);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public void plotMetrics() {\n",
    "    String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "    Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "    Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "    Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                    trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "    Table data = Table.create(\"Data\").addColumns(\n",
    "        DoubleColumn.create(\"epoch\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "        DoubleColumn.create(\"metrics\", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),\n",
    "        StringColumn.create(\"lossLabel\", lossLabel)\n",
    "    );\n",
    "\n",
    "    display(LinePlot.create(\"\", data, \"epoch\", \"metrics\", \"lossLabel\"));\n",
    "}\n",
    "\n",
    "plotMetrics();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 学习率调度器\n",
    "\n",
    "我们可以在每个迭代轮数（甚至在每个小批量）之后向下调整学习率。\n",
    "例如，以动态的方式来响应优化的进展情况。\n",
    "更通常而言，我们应该定义一个调度器。\n",
    "当调用更新次数时，它将返回学习率的适当值。\n",
    "让我们定义一个简单的方法，将学习率设置为$\\eta = \\eta_0 (t + 1)^{-\\frac{1}{2}}$。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class SquareRootTracker {\n",
    "    float lr;\n",
    "    public SquareRootTracker() {\n",
    "        this(0.1f);\n",
    "    }\n",
    "    public SquareRootTracker(float learningRate) {\n",
    "        this.lr = learningRate;\n",
    "    }\n",
    "    public float getNewLearningRate(int numUpdate) {\n",
    "        return lr * (float) Math.pow(numUpdate + 1, -0.5);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们在一系列值上绘制它的行为。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public Figure plotLearningRate(int[] epochs, float[] learningRates) {\n",
    "    \n",
    "    String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "    Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "    Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "    Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                    trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "    Table data = Table.create(\"Data\").addColumns(\n",
    "                IntColumn.create(\"epoch\", epochs),\n",
    "                DoubleColumn.create(\"learning rate\", learningRates)\n",
    "    );\n",
    "\n",
    "    return LinePlot.create(\"Learning Rate vs. Epoch\", data, \"epoch\", \"learning rate\");\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SquareRootTracker tracker = new SquareRootTracker();\n",
    "\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewLearningRate(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在让我们来看看这对在Fashion-MNIST数据集上的训练有何影响。\n",
    "我们只是提供调度器作为训练算法的额外参数。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这比以前好一些：曲线比以前更加平滑，并且过拟合更小了。\n",
    "遗憾的是，关于为什么在理论上某些策略会导致较轻的过拟合，有一些观点认为，较小的步长将导致参数更接近零，因此更简单。\n",
    "但是，这并不能完全解释这种现象，因为我们并没有真正地提前停止，而只是轻柔地降低了学习率。\n",
    "\n",
    "## 策略\n",
    "\n",
    "虽然我们不可能涵盖所有类型的学习率调度器，但我们会尝试在下面简要概述常用的策略：多项式衰减和分段常数表。\n",
    "此外，余弦学习率调度在实践中的一些问题上运行效果很好。\n",
    "在某些问题上，最好在使用较高的学习率之前预热优化器。\n",
    "\n",
    "### 多因子调度器\n",
    "\n",
    "多项式衰减的一种替代方案是乘法衰减，即$\\eta_{t+1} \\leftarrow \\eta_t \\cdot \\alpha$其中$\\alpha \\in (0, 1)$。为了防止学习率衰减超出合理的下限，更新方程经常修改为$\\eta_{t+1} \\leftarrow \\mathop{\\mathrm{max}}(\\eta_{\\mathrm{min}}, \\eta_t \\cdot \\alpha)$。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class DemoFactorTracker {\n",
    "    float baseLr;\n",
    "    float stopFactorLr;\n",
    "    float factor;\n",
    "    public DemoFactorTracker(float factor, float stopFactorLr, float baseLr) {\n",
    "        this.factor = factor;\n",
    "        this.stopFactorLr = stopFactorLr;\n",
    "        this.baseLr = baseLr;\n",
    "    }\n",
    "    public DemoFactorTracker() {\n",
    "        this(1f, (float) 1e-7, 0.1f);\n",
    "    }\n",
    "    public float getNewLearningRate(int numUpdate) {\n",
    "        return lr * (float) Math.pow(numUpdate + 1, -0.5);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DemoFactorTracker tracker = new DemoFactorTracker(0.9f, (float) 1e-2, 2);\n",
    "\n",
    "numEpochs = 50;\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewLearningRate(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们将使用内置的调度器，但在这里仅解释它们的功能。\n",
    "\n",
    "### 多因子调度器\n",
    "\n",
    "训练深度网络的常见策略之一是保持分段稳定的学习率，并且每隔一段时间就一定程度学习率降低。\n",
    "具体地说，给定一组降低学习率的时间，例如$s = \\{5, 10, 20\\}$每当$t \\in s$时降低$\\eta_{t+1} \\leftarrow \\eta_t \\cdot \\alpha$。\n",
    "假设每步中的值减半，我们可以按如下方式实现这一点。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MultiFactorTracker tracker = Tracker.multiFactor()\n",
    "        .setSteps(new int[]{5, 30})\n",
    "        .optFactor(0.5f)\n",
    "        .setBaseValue(0.5f)\n",
    "        .build();\n",
    "\n",
    "numEpochs = 10;\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewValue(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这种分段恒定学习率调度背后的直觉是，让优化持续进行，直到权重向量的分布达到一个驻点。\n",
    "此时，我们才将学习率降低，以获得更高质量的代理来达到一个良好的局部最小值。\n",
    "下面的例子展示了如何使用这种方法产生更好的解决方案。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "\n",
    "Model model = Model.newInstance(\"Modern LeNet\");\n",
    "model.setBlock(net);\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(tracker).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "train(trainDataset, testDataset, numEpochs, trainer);\n",
    "plotMetrics();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "Table data = Table.create(\"Data\").addColumns(\n",
    "            DoubleColumn.create(\"epoch\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "            DoubleColumn.create(\"metrics\", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),\n",
    "            StringColumn.create(\"lossLabel\", lossLabel)\n",
    ");\n",
    "\n",
    "LinePlot.create(\"\", data, \"epoch\", \"metrics\", \"lossLabel\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 余弦调度器\n",
    "\n",
    "余弦调度器是 :cite:`Loshchilov.Hutter.2016`提出的一种启发式算法。\n",
    "它所依据的观点是：我们可能不想在一开始就太大地降低学习率，而且可能希望最终能用非常小的学习率来“改进”解决方案。\n",
    "这产生了一个类似于余弦的调度，函数形式如下所示，学习率的值在$t \\in [0, T]$之间。\n",
    "\n",
    "$$\\eta_t = \\eta_T + \\frac{\\eta_0 - \\eta_T}{2} \\left(1 + \\cos(\\pi t/T)\\right)$$\n",
    "\n",
    "这里$\\eta_0$是初始学习率，$\\eta_T$是当$T$时的目标学习率。\n",
    "此外，对于$t > T$，我们只需将值固定到$\\eta_T$而不再增加它。\n",
    "在下面的示例中，我们设置了最大更新步数$T = 20$。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class DemoCosineTracker {\n",
    "    float baseLr;\n",
    "    float finalLr;\n",
    "    int maxUpdate;\n",
    "    public DemoCosineTracker() {\n",
    "        this(0.5f, 0.01f, 20);\n",
    "    }\n",
    "    public DemoCosineTracker(float baseLr, float finalLr, int maxUpdate) {\n",
    "        this.baseLr = baseLr;\n",
    "        this.finalLr = finalLr;\n",
    "        this.maxUpdate = maxUpdate;\n",
    "    }\n",
    "    public float getNewLearningRate(int numUpdate) {\n",
    "        if (numUpdate > maxUpdate) {\n",
    "            return finalLr;\n",
    "        }\n",
    "        // Scale the curve to smoothly transition\n",
    "        float step = (baseLr - finalLr) / 2 * (1 + (float) Math.cos(Math.PI * numUpdate / maxUpdate));\n",
    "        return finalLr + step;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DemoCosineTracker tracker = new DemoCosineTracker(0.5f, 0.01f, 20);\n",
    "\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewLearningRate(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在计算机视觉中，这个调度可以引出改进的结果。\n",
    "但请注意，如下所示，这种改进并不能保证。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CosineTracker cosineTracker = Tracker.cosine()\n",
    "                            .setBaseValue(0.5f)\n",
    "                            .optFinalValue(0.01f)\n",
    "                            .setMaxUpdates(20)\n",
    "                            .build();\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(cosineTracker).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "train(trainDataset, testDataset, numEpochs, trainer);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 预热\n",
    "\n",
    "在某些情况下，初始化参数不足以得到良好的解。\n",
    "这对于某些高级网络设计来说尤其棘手，可能导致不稳定的优化结果。\n",
    "对此，一方面，我们可以选择一个足够小的学习率，\n",
    "从而防止一开始发散，然而这样进展太缓慢。\n",
    "另一方面，较高的学习率最初就会导致发散。\n",
    "\n",
    "解决这种困境的一个相当简单的解决方法是使用预热期，在此期间学习率将增加至初始最大值，然后冷却直到优化过程结束。\n",
    "为了简单起见，通常使用线性递增。\n",
    "这引出了如下表所示的时间表。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class CosineWarmupTracker {\n",
    "    float baseLr;\n",
    "    float finalLr;\n",
    "    int maxUpdate;\n",
    "    int warmUpSteps;\n",
    "    float warmUpBeginValue;\n",
    "    float warmUpFinalValue;\n",
    "    \n",
    "    public CosineWarmupTracker() {\n",
    "        this(0.5f, 0.01f, 20, 5);\n",
    "    }\n",
    "    \n",
    "    public CosineWarmupTracker(float baseLr, float finalLr, int maxUpdate, int warmUpSteps) {\n",
    "        this.baseLr = baseLr;\n",
    "        this.finalLr = finalLr;\n",
    "        this.maxUpdate = maxUpdate;\n",
    "        this.warmUpSteps = 5;\n",
    "        this.warmUpBeginValue = 0f;\n",
    "    }\n",
    "    \n",
    "    public float getNewLearningRate(int numUpdate) {\n",
    "        if (numUpdate <= warmUpSteps) {\n",
    "            return getWarmUpValue(numUpdate);\n",
    "        }\n",
    "        if (numUpdate > maxUpdate) {\n",
    "            return finalLr;\n",
    "        }\n",
    "        // Scale the cosine curve to fit smoothly with the warmup steps\n",
    "        float step = (baseLr - finalLr) / 2 * (1 + \n",
    "            (float) Math.cos(Math.PI * (numUpdate - warmUpSteps) / (maxUpdate - warmUpSteps)));\n",
    "        return finalLr + step;\n",
    "    }\n",
    "    \n",
    "    public float getWarmUpValue(int numUpdate) {\n",
    "        // Linear warmup\n",
    "        return warmUpBeginValue + (baseLr - warmUpBeginValue) * numUpdate / warmUpSteps;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CosineWarmupTracker tracker = new CosineWarmupTracker(0.5f, 0.01f, 20, 5);\n",
    "\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewLearningRate(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意，观察前5个迭代轮数的性能，网络最初收敛得更好。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CosineTracker cosineTracker = Tracker.cosine()\n",
    "        .setBaseValue(0.5f)\n",
    "        .optFinalValue(0.01f)\n",
    "        .setMaxUpdates(15)\n",
    "        .build();\n",
    "\n",
    "WarmUpTracker warmupCosine = Tracker.warmUp()\n",
    "        .optWarmUpSteps(5)\n",
    "        .setMainTracker(cosineTracker)\n",
    "        .build();\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(warmupCosine).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "train(trainDataset, testDataset, numEpochs, trainer);\n",
    "plotMetrics();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "预热可以应用于任何调度器，而不仅仅是余弦。\n",
    "有关学习率调度的更多实验和更详细讨论，请参阅 :cite:`Gotmare.Keskar.Xiong.ea.2018`。\n",
    "其中，这篇论文的点睛之笔的发现：预热阶段限制了非常深的网络中参数的发散量。\n",
    "这在直觉上是有道理的：在网络中那些一开始花费最多时间取得进展的部分，随机初始化会产生巨大的发散。\n",
    "\n",
    "## 小结\n",
    "\n",
    "* 在训练期间逐步降低学习率可以提高准确性，并且减少模型的过拟合。\n",
    "* 在实验中，每当进展趋于稳定时就降低学习率，这是很有效的。从本质上说，这可以确保我们有效地收敛到一个适当的解，也只有这样才能通过降低学习率来减小参数的固有方差。\n",
    "* 余弦调度器在某些计算机视觉问题中很受欢迎。\n",
    "* 优化之前的预热期可以防止发散。\n",
    "* 优化在深度学习中有多种用途。对于同样的训练误差而言，选择不同的优化算法和学习率调度，除了最大限度地减少训练时间，可以导致测试集上不同的泛化和过拟合量。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 试验给定固定学习率的优化行为。这种情况下你可以获得的最佳模型是什么？\n",
    "1. 如果你改变学习率下降的指数，收敛性会如何改变？在实验中方便起见，使用`PolyScheduler`。\n",
    "1. 将余弦调度器应用于大型计算机视觉问题，例如训练ImageNet数据集。与其他调度器相比，它如何影响性能？\n",
    "1. 预热应该持续多长时间？\n",
    "1. 你能把优化和采样联系起来吗？首先，在随机梯度朗之万动力学上使用 :cite:`Welling.Teh.2011`的结果。\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
